{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates speech command recognition from a microphone's stream in NeMo.\n",
    "\n",
    "This is **not a recommended** way to run inference in production workflows. If you are interested in \n",
    "production-level inference using NeMo ASR models, please sign up for the Jarvis early access program: https://developer.nvidia.com/nvidia-jarvis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The notebook requires the PyAudio library to capture the signal from an audio device.\n",
    "On Ubuntu, run the following commands to install it:\n",
    "```\n",
    "sudo apt-get install -y portaudio19-dev\n",
    "pip install pyaudio\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import nemo\n",
    "import nemo.collections.asr as nemo_asr\n",
    "import numpy as np\n",
    "import pyaudio as pa\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Architecture and Weights\n",
    "\n",
    "The model architecture is defined in a YAML file available in the `configs` directory. MatchboxNet 3x1x64 has been trained on the Google Speech Commands dataset (v2), and its weights are available on NGC. They are downloaded automatically if not found locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the checkpoints are available from NGC: https://ngc.nvidia.com/catalog/models/nvidia:google_speech_commands_v2___matchboxnet_3x1x1\n",
    "MODEL_YAML = '../configs/quartznet_speech_commands_3x1_v2.yaml'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the checkpoint files\n",
    "base_checkpoint_path = './checkpoints/matchboxnet_v2-3x1x64/'\n",
    "CHECKPOINT_ENCODER = os.path.join(base_checkpoint_path, 'JasperEncoder-STEP-89000.pt')\n",
    "CHECKPOINT_DECODER = os.path.join(base_checkpoint_path, 'JasperDecoderForClassification-STEP-89000.pt')\n",
    "\n",
    "if not os.path.exists(base_checkpoint_path):\n",
    "    os.makedirs(base_checkpoint_path)\n",
    "    \n",
    "if not os.path.exists(CHECKPOINT_ENCODER):\n",
    "    !wget https://api.ngc.nvidia.com/v2/models/nvidia/google_speech_commands_v2___matchboxnet_3x1x1/versions/1/files/JasperEncoder-STEP-89000.pt -P {base_checkpoint_path};\n",
    "\n",
    "if not os.path.exists(CHECKPOINT_DECODER):\n",
    "    !wget https://api.ngc.nvidia.com/v2/models/nvidia/google_speech_commands_v2___matchboxnet_3x1x1/versions/1/files/JasperDecoderForClassification-STEP-89000.pt -P {base_checkpoint_path};"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construct the Neural Modules and the eval graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ruamel.yaml import YAML\n",
    "yaml = YAML(typ=\"safe\")\n",
    "with open(MODEL_YAML) as f:\n",
    "    model_definition = yaml.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "neural_factory = nemo.core.NeuralModuleFactory(\n",
    "    placement=nemo.core.DeviceType.GPU,\n",
    "    backend=nemo.core.Backend.PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a Neural Module to iterate over audio\n",
    "\n",
    "Here we define a custom Neural Module which acts as an iterator over a stream of audio that is supplied to it. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.backends.pytorch.nm import DataLayerNM\n",
    "from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType\n",
    "import torch\n",
    "\n",
    "# simple data layer to pass audio signal\n",
    "class AudioDataLayer(DataLayerNM):\n",
    "    @property\n",
    "    def output_ports(self):\n",
    "        return {\n",
    "            'audio_signal': NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate)),\n",
    "            'a_sig_length': NeuralType(tuple('B'), LengthsType()),\n",
    "        }\n",
    "\n",
    "    def __init__(self, sample_rate):\n",
    "        super().__init__()\n",
    "        self._sample_rate = sample_rate\n",
    "        self.output = True\n",
    "        \n",
    "    def __iter__(self):\n",
    "        return self\n",
    "    \n",
    "    def __next__(self):\n",
    "        if not self.output:\n",
    "            raise StopIteration\n",
    "        self.output = False\n",
    "        return torch.as_tensor(self.signal, dtype=torch.float32), \\\n",
    "               torch.as_tensor(self.signal_shape, dtype=torch.int64)\n",
    "        \n",
    "    def set_signal(self, signal):\n",
    "        self.signal = np.reshape(signal.astype(np.float32)/32768., [1, -1])\n",
    "        self.signal_shape = np.expand_dims(self.signal.size, 0).astype(np.int64)\n",
    "        self.output = True\n",
    "\n",
    "    def __len__(self):\n",
    "        return 1\n",
    "\n",
    "    @property\n",
    "    def dataset(self):\n",
    "        return None\n",
    "\n",
    "    @property\n",
    "    def data_iterator(self):\n",
    "        return self"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Instantiate the Neural Modules\n",
    "\n",
    "We now instantiate the data layer, preprocessor, encoder and decoder, load the downloaded pretrained weights into the encoder and decoder, and construct the DAG that evaluates MatchboxNet on audio streams."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate necessary neural modules\n",
    "data_layer = AudioDataLayer(sample_rate=model_definition['sample_rate'])\n",
    "\n",
    "data_preprocessor = nemo_asr.AudioToMFCCPreprocessor(\n",
    "    **model_definition['AudioToMFCCPreprocessor'])\n",
    "\n",
    "jasper_encoder = nemo_asr.JasperEncoder(\n",
    "    **model_definition['JasperEncoder'])\n",
    "\n",
    "jasper_decoder = nemo_asr.JasperDecoderForClassification(\n",
    "    feat_in=model_definition['JasperEncoder']['jasper'][-1]['filters'],\n",
    "    num_classes=len(model_definition['labels']))\n",
    "\n",
    "# load pre-trained model\n",
    "jasper_encoder.restore_from(CHECKPOINT_ENCODER)\n",
    "jasper_decoder.restore_from(CHECKPOINT_DECODER)\n",
    "\n",
    "# Define inference DAG\n",
    "audio_signal, audio_signal_len = data_layer()\n",
    "processed_signal, processed_signal_len = data_preprocessor(\n",
    "    input_signal=audio_signal,\n",
    "    length=audio_signal_len)\n",
    "encoded, encoded_len = jasper_encoder(audio_signal=processed_signal,\n",
    "                                      length=processed_signal_len)\n",
    "log_probs = jasper_decoder(encoder_output=encoded)\n",
    "\n",
    "# inference method for audio signal (single instance)\n",
    "def infer_signal(self, signal):\n",
    "    data_layer.set_signal(signal)\n",
    "    tensors = self.infer([log_probs], verbose=False)\n",
    "    logits = tensors[0][0]\n",
    "    return logits\n",
    "\n",
    "neural_factory.infer_signal = infer_signal.__get__(neural_factory)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FrameASR: Helper class for streaming inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class for streaming frame-based ASR\n",
    "# 1) use reset() method to reset FrameASR's state\n",
    "# 2) call transcribe(frame) to do ASR on\n",
    "#    contiguous frames of the signal\n",
    "class FrameASR:\n",
    "    \n",
    "    def __init__(self, neural_factory, model_definition,\n",
    "                 frame_len=2, frame_overlap=2.5, \n",
    "                 offset=10):\n",
    "        '''\n",
    "        Args:\n",
    "          neural_factory: factory that provides the infer_signal() method\n",
    "          model_definition: model config loaded from the YAML file\n",
    "          frame_len: frame's duration, seconds\n",
    "          frame_overlap: duration of overlaps before and after current frame, seconds\n",
    "          offset: number of symbols to drop for smooth streaming\n",
    "        '''\n",
    "        self.vocab = list(model_definition['labels'])\n",
    "        self.vocab.append('_')\n",
    "        \n",
    "        self.sr = model_definition['sample_rate']\n",
    "        self.frame_len = frame_len\n",
    "        self.n_frame_len = int(frame_len * self.sr)\n",
    "        self.frame_overlap = frame_overlap\n",
    "        self.n_frame_overlap = int(frame_overlap * self.sr)\n",
    "        timestep_duration = model_definition['AudioToMFCCPreprocessor']['window_stride']\n",
    "        for block in model_definition['JasperEncoder']['jasper']:\n",
    "            timestep_duration *= block['stride'][0] ** block['repeat']\n",
    "        self.buffer = np.zeros(shape=2*self.n_frame_overlap + self.n_frame_len,\n",
    "                               dtype=np.float32)\n",
    "        self.offset = offset\n",
    "        # keep a reference to the factory passed in instead of relying on a global\n",
    "        self.neural_factory = neural_factory\n",
    "        self.reset()\n",
    "        \n",
    "    def _decode(self, frame, offset=0):\n",
    "        assert len(frame) == self.n_frame_len\n",
    "        # shift the buffer left by one frame and append the new frame at the end\n",
    "        self.buffer[:-self.n_frame_len] = self.buffer[self.n_frame_len:]\n",
    "        self.buffer[-self.n_frame_len:] = frame\n",
    "        logits = self.neural_factory.infer_signal(self.buffer).to('cpu').numpy()[0]\n",
    "        decoded = self._greedy_decoder(\n",
    "            logits, \n",
    "            self.vocab\n",
    "        )\n",
    "        return decoded[:len(decoded)-offset]\n",
    "    \n",
    "    def transcribe(self, frame=None):\n",
    "        if frame is None:\n",
    "            frame = np.zeros(shape=self.n_frame_len, dtype=np.float32)\n",
    "        if len(frame) < self.n_frame_len:\n",
    "            frame = np.pad(frame, [0, self.n_frame_len - len(frame)], 'constant')\n",
    "        unmerged = self._decode(frame, self.offset)\n",
    "        \n",
    "        return unmerged\n",
    "    \n",
    "    def reset(self):\n",
    "        '''\n",
    "        Reset the audio buffer and the decoder's state\n",
    "        '''\n",
    "        self.buffer=np.zeros(shape=self.buffer.shape, dtype=np.float32)\n",
    "        self.prev_char = ''\n",
    "\n",
    "    @staticmethod\n",
    "    def _greedy_decoder(logits, vocab):\n",
    "        s = ''\n",
    "        \n",
    "        if logits.shape[0]:\n",
    "            s += str(vocab[np.argmax(logits)]) + \"\\n\"\n",
    "            \n",
    "        return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# duration of signal frame, seconds\n",
    "FRAME_LEN = 0.25\n",
    "# number of audio channels (expect mono signal)\n",
    "CHANNELS = 1\n",
    "# sample rate, Hz\n",
    "RATE = 16000\n",
    "\n",
    "CHUNK_SIZE = int(FRAME_LEN*RATE)\n",
    "asr = FrameASR(neural_factory, model_definition,\n",
    "               frame_len=FRAME_LEN, frame_overlap=2.0, \n",
    "               offset=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# What classes can this model recognize?\n",
    "\n",
    "Before we begin inference on the actual audio stream, let's look at the classes this model was trained to recognize."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = model_definition['labels']\n",
    "\n",
    "# print the labels five per row; works for any number of labels\n",
    "for i in range(0, len(labels), 5):\n",
    "    print(' '.join('%-10s' % label for label in labels[i:i + 5]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Begin listening to audio stream and perform inference using FrameASR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = pa.PyAudio()\n",
    "print('Available audio input devices:')\n",
    "for i in range(p.get_device_count()):\n",
    "    dev = p.get_device_info_by_index(i)\n",
    "    if dev.get('maxInputChannels'):\n",
    "        print(i, dev.get('name'))\n",
    "print('Please type input device ID:')\n",
    "dev_idx = int(input())\n",
    "\n",
    "empty_counter = 0\n",
    "\n",
    "def callback(in_data, frame_count, time_info, status):\n",
    "    global empty_counter\n",
    "    signal = np.frombuffer(in_data, dtype=np.int16)\n",
    "    text = asr.transcribe(signal)\n",
    "    if len(text):\n",
    "        print(text,end='')\n",
    "        empty_counter = 3\n",
    "    elif empty_counter > 0:\n",
    "        empty_counter -= 1\n",
    "        if empty_counter == 0:\n",
    "            print(' ',end='')\n",
    "    return (in_data, pa.paContinue)\n",
    "\n",
    "stream = p.open(format=pa.paInt16,\n",
    "                channels=CHANNELS,\n",
    "                rate=RATE,\n",
    "                input=True,\n",
    "                input_device_index=dev_idx,\n",
    "                stream_callback=callback,\n",
    "                frames_per_buffer=CHUNK_SIZE)\n",
    "\n",
    "print('Listening...')\n",
    "\n",
    "stream.start_stream()\n",
    "\n",
    "while stream.is_active():\n",
    "    time.sleep(0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stream.stop_stream()\n",
    "stream.close()\n",
    "p.terminate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
